PyTorch to TensorFlow Model Conversion


Contents

1 Same Results, Different Frameworks with ONNX
2 PyTorch Model Conversion Pipeline
3 Transferred Model Results

In this article, we will learn how to convert a PyTorch model to TensorFlow.

If you are new to deep learning, choosing a framework can feel overwhelming. We personally believe PyTorch is the first framework you should learn, but it need not be the only one.

The good news is that you are not married to a framework. You can train your model in PyTorch and then easily convert it to TensorFlow, as long as it uses standard layers. The best way to achieve this conversion is to first convert the PyTorch model to ONNX, and then to the TensorFlow/Keras format.
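For orientation, here is a minimal sketch of that two-step route, assuming a trained PyTorch model `model` and a representative input shape (both are placeholders):

```python
# A minimal sketch of the two-step route, assuming a trained PyTorch
# model `model` and a representative NCHW input shape (both placeholders).
import torch

dummy_input = torch.randn(1, 3, 224, 224)            # example NCHW input
torch.onnx.export(model, dummy_input, "model.onnx")  # step 1: PyTorch -> ONNX
# step 2: ONNX -> TensorFlow/Keras, e.g. with the onnx2keras tool shown below
```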

1 Same Results, Different Frameworks with ONNX

As we observed in an earlier article on the PyTorch FCN ResNet-18, the implemented model predicted the dromedary region in the picture more accurately than the TensorFlow FCN version:

[Figure: PyTorch FCN ResNet18 activations]
[Figure: TensorFlow FCN ResNet50 activations]

Suppose we want to capture these results and transfer them to another framework, for example from PyTorch to TensorFlow. Is there a way to do this? The answer is yes. One possible approach is to use the pytorch2keras library. As its name suggests, this tool provides an easy way to convert models between frameworks such as PyTorch and Keras. You can easily install it with pip:

```
pip3 install pytorch2keras
```

2 PyTorch Model Conversion Pipeline

As you can see in the pytorch2keras repository, the pipeline logic is described in converter.py. Let's look at its key points:

```python
def pytorch_to_keras(
    model, args, input_shapes=None,
    change_ordering=False, verbose=False, name_policy=None,
    use_optimizer=False, do_constant_folding=False,
):
    # ...
    # load a ModelProto structure with ONNX
    onnx_model = onnx.load(stream)
    # ...
    k_model = onnx_to_keras(
        onnx_model=onnx_model,
        input_names=input_names,
        input_shapes=input_shapes,
        name_policy=name_policy,
        verbose=verbose,
        change_ordering=change_ordering,
    )
    return k_model
```

As you may have noticed, the tool is based on the Open Neural Network Exchange (ONNX). ONNX is an open-source AI project whose goal is to make neural network models interchangeable between different tools, so that you can choose the best combination of them. The resulting top-level ONNX ModelProto container is passed to the onnx_to_keras function of the onnx2keras tool for further layer mapping.
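To make that flow concrete, here is a hedged sketch of driving onnx2keras by hand on a saved ONNX file; the file name "model.onnx" and the input name "input_0" are assumptions that must match your graph:

```python
# A hedged sketch of calling onnx2keras by hand; "model.onnx" and the
# input name "input_0" are assumptions that must match your graph.
import onnx
from onnx2keras import onnx_to_keras

onnx_model = onnx.load("model.onnx")  # top-level ModelProto container
k_model = onnx_to_keras(
    onnx_model=onnx_model,
    input_names=["input_0"],
    change_ordering=True,  # remap NCHW weights/axes to NHWC
)
```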

Let's examine the PyTorch ResNet18 conversion process, using a fully convolutional network architecture as an example:

```python
# import the transfer tool
from pytorch2keras.converter import pytorch_to_keras

def converted_fully_convolutional_resnet18(
    input_tensor,
    pretrained_resnet=True,
):
    # define the input tensor
    input_var = Variable(torch.FloatTensor(input_tensor))

    # get the PyTorch ResNet18 model
    model_to_transfer = FullyConvolutionalResnet18(pretrained=pretrained_resnet)
    model_to_transfer.eval()

    # convert the PyTorch model to Keras
    model = pytorch_to_keras(
        model_to_transfer,
        input_var,
        [input_var.shape[-3:]],
        change_ordering=True,
        verbose=False,
        name_policy="keep",
    )
    return model
```

Now we can compare the PyTorch and TensorFlow FCN versions. Let's look at the first stack of PyTorch FullyConvolutionalResnet18 layers. Note that we used the torchsummary tool to achieve visual consistency between the PyTorch and TensorFlow model summaries:

```python
from torchsummary import summary

summary(model_to_transfer, input_size=input_var.shape[-3:])
```

The output is:

```
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 363, 960]           9,408
       BatchNorm2d-2         [-1, 64, 363, 960]             128
              ReLU-3         [-1, 64, 363, 960]               0
         MaxPool2d-4         [-1, 64, 182, 480]               0
            Conv2d-5         [-1, 64, 182, 480]          36,864
       BatchNorm2d-6         [-1, 64, 182, 480]             128
              ReLU-7         [-1, 64, 182, 480]               0
            Conv2d-8         [-1, 64, 182, 480]          36,864
       BatchNorm2d-9         [-1, 64, 182, 480]             128
             ReLU-10         [-1, 64, 182, 480]               0
       BasicBlock-11         [-1, 64, 182, 480]               0
           Conv2d-12         [-1, 64, 182, 480]          36,864
      BatchNorm2d-13         [-1, 64, 182, 480]             128
             ReLU-14         [-1, 64, 182, 480]               0
           Conv2d-15         [-1, 64, 182, 480]          36,864
      BatchNorm2d-16         [-1, 64, 182, 480]             128
             ReLU-17         [-1, 64, 182, 480]               0
```

The TensorFlow model obtained after conversion with the pytorch_to_keras function contains the same layers as the initial PyTorch ResNet18 model, plus TF-specific InputLayer and ZeroPadding2D layers; the latter express the padding parameter of torch.nn.Conv2d.
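To see why the extra *_pad layers appear, consider this illustrative pairing (not taken from the converter source): PyTorch fuses padding into the convolution, while Keras makes it explicit, which matches the 725×1920 → 731×1926 → 363×960 progression in the summary above:

```python
# Illustration (not from the converter source): PyTorch fuses padding into
# the convolution, while Keras makes it explicit, hence the *_pad layers.
import torch.nn as nn
import tensorflow as tf

# ResNet18 stem in PyTorch: 7x7 conv, stride 2, padding 3
pt_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)

# equivalent Keras pair: explicit zero padding + 'valid' convolution
tf_equiv = tf.keras.Sequential([
    tf.keras.layers.ZeroPadding2D(padding=3),      # 725x1920 -> 731x1926
    tf.keras.layers.Conv2D(64, kernel_size=7,
                           strides=2, padding="valid"),  # -> 363x960
])
```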

The following summary was generated with the built-in Keras summary method of the tf.keras.Model class:

```python
model.summary()
```

The corresponding layers in the output are labeled with matching numbers for the PyTorch-to-TF mapping:

```
Layer (type)                 Output Shape          Param #
===============================================================
input_0 (InputLayer)         [(None, 725, 1920, 3  0
_______________________________________________________________
125_pad (ZeroPadding2D)      (None, 731, 1926, 3)  0
_______________________________________________________________
125 (Conv2D)                 (None, 363, 960, 64)  9408      1
_______________________________________________________________
126 (BatchNormalization)     (None, 363, 960, 64)  256       2
_______________________________________________________________
127 (Activation)             (None, 363, 960, 64)  0         3
_______________________________________________________________
128_pad (ZeroPadding2D)      (None, 365, 962, 64)  0
_______________________________________________________________
128 (MaxPooling2D)           (None, 182, 480, 64)  0         4
_______________________________________________________________
129_pad (ZeroPadding2D)      (None, 184, 482, 64)  0
_______________________________________________________________
129 (Conv2D)                 (None, 182, 480, 64)  36864     5
_______________________________________________________________
130 (BatchNormalization)     (None, 182, 480, 64)  256       6
_______________________________________________________________
131 (Activation)             (None, 182, 480, 64)  0         7
_______________________________________________________________
132_pad (ZeroPadding2D)      (None, 184, 482, 64)  0
_______________________________________________________________
132 (Conv2D)                 (None, 182, 480, 64)  36864     8
_______________________________________________________________
133 (BatchNormalization)     (None, 182, 480, 64)  256       9
_______________________________________________________________
134 (Add)                    (None, 182, 480, 64)  0
_______________________________________________________________
135 (Activation)             (None, 182, 480, 64)  0         10
_______________________________________________________________
136_pad (ZeroPadding2D)      (None, 184, 482, 64)  0
_______________________________________________________________
136 (Conv2D)                 (None, 182, 480, 64)  36864     12
_______________________________________________________________
137 (BatchNormalization)     (None, 182, 480, 64)  256       13
_______________________________________________________________
138 (Activation)             (None, 182, 480, 64)  0         14
_______________________________________________________________
139_pad (ZeroPadding2D)      (None, 184, 482, 64)  0
_______________________________________________________________
139 (Conv2D)                 (None, 182, 480, 64)  36864     15
_______________________________________________________________
140 (BatchNormalization)     (None, 182, 480, 64)  256       16
_______________________________________________________________
141 (Add)                    (None, 182, 480, 64)  0
_______________________________________________________________
142 (Activation)             (None, 182, 480, 64)  0         17
_______________________________________________________________
143_pad (ZeroPadding2D)      (None, 184, 482, 64)  0
_______________________________________________________________
```

The following scheme partially shows a visual representation of the FCN ResNet18 blocks in both the TensorFlow and PyTorch versions:

The converted TensorFlow FCN ResNet18 model is on the left, the initial PyTorch FCN ResNet18 model on the right.

The model graphs were generated with the Netron open-source viewer, which supports various model formats obtained from ONNX, TensorFlow, Caffe, PyTorch, and more. The saved model graph is passed as input to Netron, which then renders the detailed model graph.
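As a side note, Netron also ships as a Python package, so a saved graph can be opened programmatically; a small sketch, with the file name as a placeholder:

```python
# A small sketch: Netron also ships as a Python package, so a saved
# graph can be opened programmatically (file name is a placeholder).
import netron

netron.start("model.onnx")  # serves the graph and opens it in the browser
```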

3 Transferred Model Results

So, we converted the whole PyTorch FCN ResNet-18 model with its weights to TensorFlow, changing the NCHW (batch size, channels, height, width) format to NHWC with the change_ordering=True parameter.

This is done because in the PyTorch model the shape of the input layer is 3×725×1920, while in TensorFlow the input layer shape is changed to 725×1920×3, since NHWC is the default data format in TF. We should also remember that to get the same prediction shape as in PyTorch, (1, 1000, 3, 8), we have to transpose the network output once again:

```python
# NHWC: (1, 725, 1920, 3)
predict_image = tf.expand_dims(image, 0)
# NCHW: (1, 3, 725, 1920)
image = np.transpose(tf.expand_dims(image, 0).numpy(), [0, 3, 1, 2])

# get the transferred torch ResNet18 with pre-trained ImageNet weights
model = converted_fully_convolutional_resnet18(
    input_tensor=image,
    pretrained_resnet=True,
)

# Perform inference.
# Instead of a 1x1000 vector, we get a 1x1000 x n x m output
# (i.e. a probability map of size n x m for each of the 1000
# classes, where n and m depend on the size of the image).
preds = model.predict(predict_image)

# NHWC: (1, 3, 8, 1000) back to NCHW: (1, 1000, 3, 8)
preds = tf.transpose(preds, (0, 3, 1, 2))
preds = tf.nn.softmax(preds, axis=1)
```
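Before trusting the transferred model, it is common to verify that the two frameworks agree numerically. A minimal sketch, assuming `model_to_transfer` from the conversion step is still in scope and comparing the raw outputs before softmax:

```python
# A minimal sketch, assuming `model_to_transfer` is still in scope.
import numpy as np
import torch

# run the original PyTorch model on the same NCHW input
with torch.no_grad():
    torch_logits = model_to_transfer(torch.from_numpy(image).float()).numpy()

# compare against the TF output, transposed back to NCHW, before softmax
tf_logits = tf.transpose(model.predict(predict_image), (0, 3, 1, 2)).numpy()
np.testing.assert_allclose(torch_logits, tf_logits, rtol=1e-3, atol=1e-4)
```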

Another point worth mentioning is image preprocessing. Recall that in the TF fully convolutional ResNet50, a special preprocess_input utility function was applied. Here, however, for the converted TF model we use the same normalization as in the PyTorch FCN ResNet-18 case:

```python
# transform the input image
transform = Compose(
    [
        Normalize(
            # subtract the mean
            mean=(0.485, 0.456, 0.406),
            # divide by the standard deviation
            std=(0.229, 0.224, 0.225),
        ),
    ],
)

# apply the image transformations, (725, 1920, 3)
image = transform(image=image)["image"]
```

Let's examine the results:

```
Response map shape : (1, 1000, 3, 8)
Predicted Class : Arabian camel, dromedary, Camelus dromedarius tf.Tensor(354, shape=(), dtype=int64)
```
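For completeness, here is one way the class label above could be derived from the response map; this is a sketch consistent with the shapes above, not necessarily the post's exact code:

```python
# A sketch (not necessarily the post's exact code): collapse the spatial
# response map to per-class scores, then take the argmax -> class 354.
class_scores = tf.reduce_max(preds, axis=[2, 3])  # shape (1, 1000)
class_idx = tf.argmax(class_scores, axis=1)[0]    # tf.Tensor(354, dtype=int64)
```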

The predicted class is correct. Let's take a look at the response map: as you can see, the response area is the same as in the earlier PyTorch FCN post.

[Figure: converted TF FCN ResNet18 results]


